
Modular Diffusers Guiders #11311


Open · wants to merge 18 commits into base: modular-refactor

Conversation

@a-r-r-o-w (Member) commented on Apr 14, 2025:

The following guidance methods are currently supported: Classifier-Free Guidance (CFG), CFG++, CFG-Zero* (CFGZ), Adaptive Projected Guidance (APG), Skip Layer Guidance (SLG, which also covers PAG and STG), Smoothed Energy Guidance (SEG), Tangential Classifier-Free Guidance (TCFG), and AutoGuidance.

Note: PAG is implemented as Skip Layer Guidance and does not have its own guider implementation. The equivalent SLG initialization is:

from diffusers import SkipLayerGuidance, LayerSkipConfig

config = LayerSkipConfig(indices=[2, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_ff=False, skip_attention_scores=True)
slg = SkipLayerGuidance(guidance_scale=5.0, skip_layer_guidance_scale=2.5, skip_layer_config=config)

Note: STG is also implemented as Skip Layer Guidance:

  • STG-r: skip_attention=False, skip_ff=True, skip_attention_scores=False
  • STG-a: skip_attention=True, skip_ff=False, skip_attention_scores=False
  • STG-t: skip_attention=True, skip_ff=True, skip_attention_scores=False
  • STG-v: skip_attention=False, skip_ff=False, skip_attention_scores=True (essentially PAG)

Note: You can use different SLG configs for different parts of the model. Create multiple configs and pass them as a list to skip_layer_config, as in the sketch below.
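
A minimal sketch of the list form, reusing the LayerSkipConfig arguments shown above (the indices and fqn values here are purely illustrative):

from diffusers import SkipLayerGuidance, LayerSkipConfig

# Two configs targeting different parts of the UNet, passed together as a list
configs = [
    LayerSkipConfig(indices=[2, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_ff=False, skip_attention_scores=True),
    LayerSkipConfig(indices=[0, 1], fqn="down_blocks.1.attentions.1.transformer_blocks", skip_attention=True, skip_ff=False, skip_attention_scores=False),
]
slg = SkipLayerGuidance(guidance_scale=5.0, skip_layer_guidance_scale=2.5, skip_layer_config=configs)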

Results (collapsed sections): APG · CFG · CFGZ · SLG (Skip Attention Scores, Skip FF) · SLG (Skip Attention Scores) · SLG (Skip Attention, Skip FF) · SLG (Skip Attention) · SLG (Skip FF)
Minimal all-guiders testing script:
from pathlib import Path

import torch
import torch.nn.functional as F
from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.guiders import AdaptiveProjectedGuidance, AutoGuidance, CFGPlusPlusGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance
from diffusers.hooks import LayerSkipConfig, SmoothedEnergyGuidanceConfig

output_dir = "dump_modular_diffusers"
Path(output_dir).mkdir(parents=True, exist_ok=True)

components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
components.enable_auto_cpu_offload(device="cuda:0")

pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
pipe.update_states(**components.components)
pipe.to("cuda")

prompt = "A majestic lion jumping from a big stone at night"
height = 1024
width = 1024



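# Classifier-Free Guidance (CFG)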
cfg = ClassifierFreeGuidance(guidance_scale=10.0, guidance_rescale=0.0, use_original_formulation=False, start=0.0, stop=1.0)
pipe.update_states(guider=cfg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-cfg.png")



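# Skip Layer Guidance (SLG): skip attention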
config = LayerSkipConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=True, skip_ff=False, skip_attention_scores=False)
slg = SkipLayerGuidance(guidance_scale=7.5, skip_layer_guidance_scale=2.5, skip_layer_config=config)
pipe.update_states(guider=slg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-slg---skip_attention.png")



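# Skip Layer Guidance (SLG): skip attention + feed-forward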
config = LayerSkipConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=True, skip_ff=True, skip_attention_scores=False)
slg = SkipLayerGuidance(guidance_scale=7.5, skip_layer_guidance_scale=2.5, skip_layer_config=config)
pipe.update_states(guider=slg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-slg---skip_attention---skip_ff.png")



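# Skip Layer Guidance (SLG): skip feed-forward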
config = LayerSkipConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_ff=True, skip_attention_scores=False)
slg = SkipLayerGuidance(guidance_scale=7.5, skip_layer_guidance_scale=2.5, skip_layer_config=config)
pipe.update_states(guider=slg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-slg---skip_ff.png")



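# Skip Layer Guidance (SLG): skip attention scores (PAG-equivalent)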
config = LayerSkipConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_ff=False, skip_attention_scores=True)
slg = SkipLayerGuidance(guidance_scale=7.5, skip_layer_guidance_scale=2.5, skip_layer_config=config)
pipe.update_states(guider=slg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-slg---skip_attention_scores.png")



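# Skip Layer Guidance (SLG): skip attention scores + feed-forward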
config = LayerSkipConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_ff=True, skip_attention_scores=True)
slg = SkipLayerGuidance(guidance_scale=7.5, skip_layer_guidance_scale=2.5, skip_layer_config=config)
pipe.update_states(guider=slg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-slg---skip_attention_scores---skip_ff.png")



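# Adaptive Projected Guidance (APG)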
apg = AdaptiveProjectedGuidance(guidance_scale=12.0, adaptive_projected_guidance_momentum=-0.5, adaptive_projected_guidance_rescale=10.0)
pipe.update_states(guider=apg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-apg.png")



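# Classifier-Free Zero* Guidance (CFGZ)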
# Should not set zero_init_steps > 0 for non-flow-matching schedulers
cfgz = ClassifierFreeZeroStarGuidance(guidance_scale=10.0, zero_init_steps=0)
pipe.update_states(guider=cfgz)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-cfgz.png")



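# AutoGuidance with two layer-skip configs (with dropout)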
configs = []
configs.append(LayerSkipConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=True, skip_ff=True, skip_attention_scores=False, dropout=0.1))
configs.append(LayerSkipConfig(indices=[0, 1], fqn="down_blocks.1.attentions.1.transformer_blocks", skip_attention=True, skip_ff=False, skip_attention_scores=False, dropout=0.05))
ag = AutoGuidance(guidance_scale=10.0, auto_guidance_config=configs, use_original_formulation=False, start=0.0, stop=1.0)
pipe.update_states(guider=ag)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-ag.png")



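# Smoothed Energy Guidance (SEG)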
config = SmoothedEnergyGuidanceConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks")
seg = SmoothedEnergyGuidance(guidance_scale=7.5, seg_guidance_scale=2.5, seg_blur_sigma=9999999.0, seg_guidance_config=config)
pipe.update_states(guider=seg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-seg.png")


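# Tangential Classifier-Free Guidance (TCFG)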
tcfg = TangentialClassifierFreeGuidance(guidance_scale=10.0, guidance_rescale=0.0, use_original_formulation=False, start=0.00, stop=1.0)
pipe.update_states(guider=tcfg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-tcfg.png")



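# CFG++ (expects EulerDiscreteScheduler, see the assert below)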
cfgpp = CFGPlusPlusGuidance(guidance_scale=0.9, guidance_rescale=0.0, use_original_formulation=False, start=0.0, stop=1.0)
assert pipe.scheduler.__class__.__name__ == "EulerDiscreteScheduler"
pipe.update_states(guider=cfgpp)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-cfgpp.png")
YiYi's modified full test script:
import os
import torch
import numpy as np
import cv2
from PIL import Image

from diffusers import (
    ControlNetModel,
    ModularPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
    AdaptiveProjectedGuidance,
    ClassifierFreeGuidance,
    SkipLayerGuidance,
    LayerSkipConfig,
)
from diffusers.utils import load_image
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import StableDiffusionXLAutoPipeline, StableDiffusionXLIPAdapterStep

from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

from controlnet_aux import LineartAnimeDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:0"
dtype = torch.float16
num_images_per_prompt = 1

test_pag = True
test_lora = False


# define output folder
out_folder = "dump_modular_diffusers"
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_mem(mem_size, name):
    mem_gb = mem_size / 1024**3
    mem_mb = mem_size / 1024**2
    print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for a specific device.
    
    Args:
        device_id (int): GPU device ID
    """
    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")

# function to make canny image (for controlnet)
def make_canny(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)


# (1) Define inputs
# for text2img/img2img
prompt = "a bear sitting in a chair drinking a milkshake"
negative_prompt = "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality"

# for img2img
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
init_image = load_image(url).convert("RGB")
strength = 0.9 

# for controlnet
control_image = make_canny(init_image)
controlnet_conditioning_scale = 0.5  # recommended for good generalization
# for controlnet_union
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_union_image = processor(init_image, output_type="pil")

# for inpainting
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

inpaint_image = load_image(img_url).resize((1024, 1024))
inpaint_mask = load_image(mask_url).resize((1024, 1024))
inpaint_control_image = make_canny(inpaint_image)
inpaint_strength = 0.99

# for ip adapter
ip_adapter_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")


# (2) define blocks and nodes (builder)

auto_pipeline_block = StableDiffusionXLAutoPipeline()
auto_pipeline = ModularPipeline.from_block(auto_pipeline_block)
refiner_pipeline = ModularPipeline.from_block(auto_pipeline_block)



# (3) add states to nodes
repo = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
controlnet_repo = "diffusers/controlnet-canny-sdxl-1.0"
inpaint_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
vae_fix_repo = "madebyollin/sdxl-vae-fp16-fix"
controlnet_union_repo = "brad-twinkl/controlnet-union-sdxl-1.0-promax"
ip_adapter_repo = "h94/IP-Adapter"


components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)


controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
components.add("controlnet", controlnet)

image_encoder = CLIPVisionModelWithProjection.from_pretrained(ip_adapter_repo, subfolder="sdxl_models/image_encoder", torch_dtype=dtype)
feature_extractor = CLIPImageProcessor(size=224, crop_size=224)

components.add("image_encoder", image_encoder)
components.add("feature_extractor", feature_extractor)


# load components/config into nodes
auto_pipeline.update_states(**components.components)


# load other components to swap in later
refiner_unet = UNet2DConditionModel.from_pretrained(refiner_repo, subfolder="unet", torch_dtype=dtype)
inpaint_unet = UNet2DConditionModel.from_pretrained(inpaint_repo, subfolder="unet", torch_dtype=dtype)
vae_fix = AutoencoderKL.from_pretrained(vae_fix_repo, torch_dtype=dtype)
controlnet_union = ControlNetUnionModel.from_pretrained(controlnet_union_repo, torch_dtype=dtype)

components.add("refiner_unet", refiner_unet)
components.add("inpaint_unet", inpaint_unet)
components.add("controlnet_union", controlnet_union)
components.add("vae_fix", vae_fix)


# you can add guiders to the manager too, but there is no need because they are not serialized
pag_guider = SkipLayerGuidance(
    guidance_scale=5.0,
    skip_layer_guidance_scale=3.0,
    skip_layer_config=LayerSkipConfig(
        indices=[2, 3, 7, 8],
        fqn="mid_block.attentions.0.transformer_blocks",
        skip_attention=False,
        skip_ff=False,
        skip_attention_scores=True,
    ),
    start=0.0,
    stop=1.0,
)
cfg_guider = ClassifierFreeGuidance(guidance_scale=5.0)


# (4) enable auto cpu offload: automatically offload models when available gpu memory goes below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()



# using auto_pipeline to generate images

# to get info about auto_pipeline and how to use it: inputs/outputs/components
# this is an "auto" workflow that works for all use cases: text2img, img2img, inpainting, controlnet, etc.
print(f" ")
print(f" auto_pipeline:")
print(auto_pipeline)
print(" ")


# since we want the text2img use case, we can run the following to see the components/blocks/inputs for it
print(f" ")
print(f" auto_pipeline info (default use case: text2img)")
print(auto_pipeline.get_execution_blocks())
print(" ")

# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what the sdxl_node told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test1_out_text2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test1_out_text2img.png")

clear_memory()


# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
generator = torch.Generator(device="cuda").manual_seed(0)
auto_pipeline.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test2_out_text2img_lora_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test2_out_text2img_lora.png")

# test3: text2image with pag
print(f" ")
print(f" running test3: text2image with pag")
if not test_lora:
    auto_pipeline.unload_lora_weights()

auto_pipeline.update_states(guider=pag_guider)
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test3_out_text2img_pag_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# checkout the components if you want; the models used are moved to device, some might get offloaded to cpu
# print(components)


# test4: SDXL(text2img) with ip_adapter (pag guider still set from test3)
print(f" ")
print(f" running test4: SDXL(text2img) with ip_adapter")

auto_pipeline.load_ip_adapter(ip_adapter_repo, subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
auto_pipeline.set_ip_adapter_scale(0.6)

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    ip_adapter_image=ip_adapter_image,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test  4_out_text2img_ip_adapter_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test4_out_text2img_ip_adapter.png")

auto_pipeline.unload_ip_adapter()
clear_memory()

# test5: SDXL(text2img) with controlnet

# we are going to pass a new input now `control_image` so the workflow will be automatically converted to controlnet use case
# let's checkout the info for controlnet use case
print(f" auto_pipeline info (controlnet use case)")
print(auto_pipeline.get_execution_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test5: SDXL(text2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test5_out_text2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test5_out_text2img_control.png")

clear_memory()

# test6: SDXL(img2img)

print(f" ")
print(f" running test6: SDXL(img2img)")

generator = torch.Generator(device="cuda").manual_seed(0)

# let's checkout the sdxl_node info for img2img use case
print(f" auto_pipeline info (img2img use case)")
print(auto_pipeline.get_execution_blocks("image"))
print(" ")

images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test6_out_img2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test6_out_img2img.png")

clear_memory()


# test7: SDXL(img2img) with controlnet
# let's checkout the sdxl_node info for img2img controlnet use case
print(f" sdxl_node info (img2img controlnet use case)")
print(auto_pipeline.get_execution_blocks("image", "control_image"))
print(" ")

print(f" ")
print(f" running test7: SDXL(img2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    generator=generator, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test7_out_img2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test7_out_img2img_control.png")

clear_memory()

# test8: img2img with refiner

refiner_pipeline.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)
# let's checkout the refiner_node
print(f" refiner_pipeline info")
print(refiner_pipeline)
print(f" ")

print(f" refiner_pipeline: triggered by `image_latents`")
print(refiner_pipeline.get_execution_blocks("image_latents"))
print(" ")

print(f" running test8: img2img with refiner")


generator = torch.Generator(device="cuda").manual_seed(0)
latents = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)
images_output = refiner_pipeline(
    image_latents=latents,  
    prompt=prompt, 
    denoising_start=0.8, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test8_out_img2img_refiner_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test8_out_img2img_refiner.png")

clear_memory()

# test9: SDXL(inpainting)
# let's checkout the sdxl_node info for inpainting use case
print(f" auto_pipeline info (inpainting use case)")
print(auto_pipeline.get_execution_blocks("mask_image", "image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting)")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test9_out_inpainting_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test9_out_inpainting.png")

clear_memory()

# test10: SDXL(inpainting) with controlnet
# let's checkout the sdxl_node info for inpainting + controlnet use case
print(f" auto_pipeline info (inpainting + controlnet use case)")
print(auto_pipeline.get_execution_blocks("mask_image", "control_image"))
print(" ")

print(f" ") 
print(f" running test10: SDXL(inpainting) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    control_image=control_image, 
    image=init_image,
    height=1024,
    width=1024,
    mask_image=inpaint_mask,
    num_images_per_prompt=num_images_per_prompt,
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test10_out_inpainting_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test10_out_inpainting_control.png")

clear_memory()

# test11: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet")

auto_pipeline.update_states(unet=components.get("inpaint_unet"))
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test11_out_inpainting_inpaint_unet_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test11_out_inpainting_inpaint_unet.png")

clear_memory()


# test12: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test12: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

generator = torch.Generator(device="cuda").manual_seed(0)

images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    padding_mask_crop=33, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()

# test13: apg

print(f" ")
print(f" running test13: apg")

apg_guider = AdaptiveProjectedGuidance(guidance_scale=15.0, adaptive_projected_guidance_momentum=-0.3, adaptive_projected_guidance_rescale=12.0, start=0.01)
auto_pipeline.update_states(guider=apg_guider, unet=components.get("unet"))


generator = torch.Generator().manual_seed(0)
images_output = auto_pipeline(
  prompt=prompt, 
  generator=generator,
  num_inference_steps=20,
  num_images_per_prompt=1, # yiyi: apg does not work with num_images_per_prompt > 1
  guidance_scale=15,
  height=896,
  width=768,
  output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test13_out_apg_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test13_out_apg.png")

clear_memory()


# test14: SDXL(text2img) with controlnet_union

auto_pipeline.update_states(controlnet=components.get("controlnet_union"), unet=components.get("unet"), vae=components.get("vae_fix"), guider=pag_guider)
# we are going to pass a new input now `control_mode` so the workflow will be automatically converted to controlnet use case
# let's checkout the info for controlnet use case
print(f" auto_pipeline info (controlnet union use case)")
print(auto_pipeline.get_execution_blocks("control_mode"))
print(" ")

print(f" ")
print(f" running test14: SDXL(text2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)

images_output = auto_pipeline(
    prompt=prompt, 
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test14_out_text2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test14_out_text2img_control_union.png")

clear_memory()


# test15: SDXL(img2img) with controlnet_union

print(f" ")
print(f" auto_pipeline info (img2img controlnet union use case)")
print(auto_pipeline.get_execution_blocks("image", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(img2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    generator=generator, 
    control_mode=[3], 
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt, 
    height=1024, 
    width=1024, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test15_out_img2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test15_out_img2img_control_union.png")

clear_memory()

# test16: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" auto_pipeline info (inpainting controlnet union use case)")
print(auto_pipeline.get_execution_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test16: SDXL(inpainting) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    mask_image=inpaint_mask, 
    control_image=controlnet_union_image,
    control_mode=[3],
    height=1024, 
    width=1024, 
    generator=generator, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test16_out_inpainting_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test16_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w mentioned this pull request on Apr 14, 2025
@a-r-r-o-w marked this pull request as ready for review on April 14, 2025 15:04
@a-r-r-o-w requested a review from yiyixuxu on April 14, 2025 15:04
@a-r-r-o-w (Member Author) commented:

@vladmandic's suggestion about having a universal start/stop parameter from here is now implemented too. Note, however, that the guiders already support any kind of dynamic schedule, with multiple enabling/disabling per inference, if the user modifies the properties on the guider object (see this comment for example).
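
For example, restricting a guider to a fraction of the denoising schedule only takes the start/stop constructor arguments (a minimal sketch reusing the API from the test scripts below; the exact values are illustrative):

from diffusers.guiders import ClassifierFreeGuidance

# apply CFG only between 10% and 80% of the inference steps (illustrative values)
cfg = ClassifierFreeGuidance(guidance_scale=7.5, start=0.1, stop=0.8)
pipe.update_states(guider=cfg)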

Batched inference is still supported too (in terms of multiple prompts and setting num_images_per_prompt > 1)! It's just that it is not done by batching the conditional and unconditional branches together. This can be handled lazily eventually, but I'm prioritizing getting the methods to work first before doing anything too complex/time-consuming. We need to design in a way that caching methods, and potentially other techniques that we couldn't support before, can be made compatible easily.

@@ -0,0 +1,271 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
@a-r-r-o-w (Member Author):

For more context on why we need this, see #10875 and this comment.

I discussed with Dhruv and for now we should keep it. After either the FBC or the Guider PR is merged to main, I can do the refactor and make use of decorators. This will save me the burden of implementing the same thing in both PRs and maintaining it until one gets merged, but rest assured I'll do the refactor before the next release.


def _register_attention_processors_metadata():
# AttnProcessor2_0
AttentionProcessorRegistry.register(
@a-r-r-o-w (Member Author):

For now, only this and BasicTransformerBlock are relevant, since modular diffusers only supports SDXL. The rest is copied over, but we keep it to avoid merge conflicts since the FirstBlockCache PR will most likely be merged before modular diffusers.

return noise_cfg


def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
@yiyixuxu (Collaborator) commented on Apr 14, 2025:

I think it is easier to work with if we:

  1. provide a default method here in guider_utils.py to deal with a list of inputs like you specified here: each element could be a tensor or a tuple/list of tensors; this logic should be mostly the same for different guiders, no?
  2. let each specific guider class define how to prepare each input element

basically, the method here becomes something like this; would this make sense?

def prepare_inputs(self, denoiser: torch.nn.Module, num_conditions: int, *args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> Tuple[List[torch.Tensor], ...]:
    """
    Prepares the inputs for the denoiser by processing each argument individually using a helper method.
    """
    list_of_inputs = []

    for arg in args:
        if isinstance(arg, (tuple, list)):
            if len(arg) != 2:
                raise ValueError("...")
        elif not isinstance(arg, torch.Tensor):
            raise ValueError("...")
        processed_input = self.prepare_input_single(arg, num_conditions)

        list_of_inputs.append(processed_input)

    return tuple(list_of_inputs)

@a-r-r-o-w (Member Author):

Sounds good, I'll update the implementations

added_cond_kwargs=data.added_cond_kwargs,
return_dict=False,
)[0]
data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)
@yiyixuxu (Collaborator) commented on Apr 14, 2025:

is it possible to do something like this?

noise_pred_outputs = []
for batch_index, (...) in enumerate(zip(...)):
    latents_i = ...
    noise_pred = pipeline.unet(...)
    noise_pred_outputs = pipeline.guider.prepare_and_add_output(pipeline.unet, noise_pred, noise_pred_outputs)

@a-r-r-o-w (Member Author):

Hey, there were a few more changes related to guiders. Basically, they also need information like sigmas (see above explanation for CFG++ if we go forward with implementing it that way), latent height/width (for methods like SEG/SAG), tensor formats (SAG), extra prompt information (methods like Attend-and-Excite), and probably more.

I haven't added SAG and A&E because it would be complicated to review with all the required changes. Since we are aiming for modularity, though, the design should still allow for such use cases. I'm not quite sure how to proceed yet, but please take another look and LMK what you think.

Comment on lines 88 to 93
if self._is_cfgpp_enabled():
# TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later!
pred_cond = self._preds["pred_cond"]
pred_uncond = self._preds["pred_uncond"]
diff = pred_uncond - pred_cond
pred = pred + diff * self.guidance_scale * self._sigma_next
@a-r-r-o-w (Member Author) commented on Apr 16, 2025:

The original repository implements CFG++ in a different way. I wanted to try and make it work without really modifying all our schedulers, and so it's done this way. The math works out the same.

For context, in our schedulers, we do:

new_sample = sample + model_output_after_cfg * (sigmas[i + 1] - sigmas[i])
new_sample = sample - model_output_after_cfg * sigmas[i] + model_output_after_cfg * sigmas[i + 1]

What we need to do for CFG++ is this instead:

new_sample = sample - model_output_after_cfg * sigmas[i] + model_output_uncond * sigmas[i + 1]

(This is only for EulerDiscreteScheduler and will differ for other schedulers.) Concretely, since model_output_after_cfg = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond), the correction to apply after the standard step is (model_output_uncond - model_output_after_cfg) * sigmas[i + 1] = guidance_scale * (model_output_uncond - model_output_cond) * sigmas[i + 1], which is exactly the diff * guidance_scale * sigma_next term applied in the snippet at the top of this thread.

After a little bit of working it out on paper, I found that some different schedulers don't really have to be modified if we add and subtract some terms after the scheduler step. We will need to have some specialized code (it can either exist in this file or the scheduler file) to add/subtract the right terms for each scheduler, so LMK how you think we should do it

Nevermind, it's better to just do this in the scheduler

@a-r-r-o-w (Member Author):

Also cc @DN6 for all the custom hook implementations

@a-r-r-o-w requested a review from DN6 on April 16, 2025 12:41
@@ -668,7 +675,38 @@ def step(
dt = self.sigmas[self.step_index + 1] - sigma_hat

prev_sample = sample + derivative * dt

if _use_cfgpp:
@yiyixuxu (Collaborator):

so we are hoping to find a scalable solution that can provide maximum support for community creativity. It isn't scalable if it requires code changes to the schedulers.

I think it can be manipulated inside the guider, no? Since we have all the variables in the pipeline state and all the components in the model states, you can use those to access the scheduler and the sigmas counter

@a-r-r-o-w (Member Author):

For Euler, yes, it is easy to add a correction term outside the scheduler step and make it work -- this is how it was originally implemented.

For DDIM, DPM++, and all the others, it quickly gets very complicated to handle all the correction terms correctly, since you need to recalculate a lot of variables for the original model_output, subtract them out, calculate the correct variables using model_pred_uncond, and add that in. I don't think that having specialized code in the guider to handle all usable schedulers, probably using isinstance checks, is a good approach.

@yiyixuxu (Collaborator):

yeah, agree it should not be in the guiders;
there is a dependency between guider & scheduler: our scheduler implementations are not aware of the guidance approach since they were all designed to work with CFG.
we can revisit this last, but I would lean more towards making new schedulers for CFG++ since it basically requires a new step function

@a-r-r-o-w (Member Author):

I think we should try to see if there's a way to break down existing schedulers into smaller methods, and consider the possibility of overriding the behaviour given certain params from the user. Fully re-implementing each CFG++-supported scheduler will probably just become combinatorial explosion hell. There is also a need to consider more techniques that come up, which might require tweaking just small aspects of the scheduler, and we should be able to make the experience of such integrations better/easier ("Assemble like Lego" for modular diffusers).

@yiyixuxu (Collaborator) commented on Apr 21, 2025:

I agree,

not sure about breaking existing schedulers into smaller methods, though; they are already pretty small. I think allowing overriding should be sufficient :) We currently already sort of allow overriding set_timesteps by accepting custom timesteps created by the users, but it is a bit hacky/not very nice.

we should find a way to support different set_timesteps & step methods very easily (maybe something similar to attention processors, but it should be much simpler)

how about we only support one scheduler for CFG++ in this PR and we do a scheduler refactor as a follow-up?

@yiyixuxu (Collaborator) left a review comment:

I left a few more comments, for discussion only at this point. No need to do anything for now; let's find a design we are happy with first.

if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
self._preds[key] = pred
@yiyixuxu (Collaborator) commented on Apr 20, 2025:

let's try not to store tensors inside the guider class unless we have to;
this can go into the guider_data if we decide to make one

@a-r-r-o-w (Member Author):

This may be possible. The guider decides which batches of input it processes - each may be a batch that results in pred_cond, pred_uncond, pred_cond_skip, and so on. The guider will need to maintain this state information (i.e. which batch of data it is currently processing), but the modular pipeline can pull this info and maintain the output dict. If this sounds good, I'll update the implementations.

@yiyixuxu (Collaborator):

sounds good! let's try:)

# prepare latents for controlnet using the guider
data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents)
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)

@yiyixuxu (Collaborator) commented on Apr 20, 2025:

maybe easier to put all the batched inputs into their own data class, something like this:

pipeline.guider.set_input_fields(
    latents="latents",
    prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
    # we can make a data field to indicate it is conditional or not
    is_uncond=(False, True),
    ...
)
# this should return a list, or tuple
batched_guider_data = pipeline.guider.prepare_inputs(data)

for batch in batched_guider_data:
    # instead of latents_i, we can access and update via batch.latents, which corresponds to guider_data[i].latents
    batch.latents = pipeline.scheduler.scale_model_input(batch.latents, t)
    ...
    added_cond_kwargs = {
        "text_embeds": batch.pooled_prompt_embeds,
        "time_ids": batch.add_time_ids,
    }
    ...
    if batch.is_uncond and data.guess_mode:
        down_block_res_samples = [torch.zeros_like(d) for d in down_block_res_samples]
    else:
        down_block_res_samples, mid_block_res_sample = pipeline.controlnet(batch.latents, ...)
    ...
    # each batch has its own model_output
    batch.noise_pred = pipeline.unet(batch.latents, ...)

# Perform guidance
# I think we can combine the guider.prepare_outputs & guider forward pass
data.noise_pred = pipeline.guider(batched_guider_data, ...)

@a-r-r-o-w (Member Author):

# we can make a data field to indicate it is conditional or not

Since there has been no concrete use case of using more than a conditional and unconditional batch of data, I think we could simply enforce the convention that:

  • if one value is passed, it can correspond to either conditional or unconditional data. It does not really matter which, because the user controls what they are passing
  • if two values are passed, the first value corresponds to conditional data and the second value corresponds to unconditional data.

Do you know of an example where more than one of each cond/uncond is used, in order to have the is_uncond identifier?

# this should return a list, or tuple
batched_guider_data = pipeline.guider.prepare_inputs(data)

This sounds good to me. We provide all the input fields as available in data and pull values out of it when prepare_inputs is called within the guider. This way, the guider has access to data and can pull further information out of it (such as the excite tokens required in Attend-and-Excite).

# each batch has its own model_output

Sounds good to me. Each batch in batched_guider_data will need an associated ID for this to work. For example, for SLG, each batch will need to be marked as pred_cond, pred_uncond, or pred_cond_skip (but any of these could be dynamically disabled/enabled by changing values with a callback or directly on the guidance object).

@yiyixuxu (Collaborator):

Since there has been no concrete use case of using more than a conditional and unconditional batch of data, I think we could simply enforce the convention that:

if one value is passed, it can correspond to either conditional or unconditional data. It does not really matter which, because the user controls what they are passing
if two values are passed, the first value corresponds to conditional data and the second value corresponds to unconditional data.

sounds good to me!

@yiyixuxu (Collaborator) commented on Apr 21, 2025:

another thing I have in mind, and will try to play around with a little bit this week, is to create a LoopSequentialPipelineBlocks that's similar to SequentialPipelineBlocks but comes with the loop and default guider behavior

this way the denoising loop itself will be modular too, e.g. you can just add controlnet/inpaint into your denoise block instead of rewriting a new denoising block that does these things; and this way we do not have to support additional callbacks

@yiyixuxu (Collaborator):

(also, I won't add more commits to the refactor PR until this PR is merged, so don't worry about merge conflicts)
